Find and analyse drought events#

import sys
import os
import glob
import xarray as xr
from functools import partial
import datetime
import numpy as np
import plotly.graph_objects as go
import dask.array as da
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import label, generate_binary_structure
import geopandas as gpd
import pandas as pd
from scipy.ndimage import label, generate_binary_structure
import hvplot.xarray  # to plot xarray with hvplot
import cartopy.crs as ccrs

Load data function#

def get_spi_dataset(acc_period: str = 1, years: list = [2020]):
    data_root_folder = '/data1/drought_dataset/spi/'
    spi_folder = os.path.join(data_root_folder, f'spi{acc_period}')
    spi_paths = []

    for year in years:
        spi_paths.extend(sorted(glob.glob(
            f'{data_root_folder}spi{acc_period}/SPI{acc_period}_gamma_global_era5_moda_ref1991to2020_{year}*.nc')))

    return xr.open_mfdataset(spi_paths, chunks={'time': "auto"}, concat_dim="time", combine='nested', parallel=False)


def get_spei_dataset(acc_period: str = 1, years: list = [2020]):
    data_root_folder = '/data1/drought_dataset/spei/'
    spi_folder = os.path.join(data_root_folder, f'spi{acc_period}')
    spi_paths = []

    for year in years:
        spi_paths.extend(sorted(glob.glob(
            f'{data_root_folder}spei{acc_period}/SPEI{acc_period}_genlogistic_global_era5_moda_ref1991to2020_{year}*.nc')))

    return xr.open_mfdataset(spi_paths, chunks={'time': "auto"}, concat_dim="time", combine='nested', parallel=False)


def mask_invalid_values(ds, variable, value=-9999):
    ds[variable] = ds[variable].where(ds[variable] != value, np.nan)
    return ds


def subset_region(dataset, variable, bbox):
    # data = dataset.sel(time=np.datetime64(time), method='nearest')[variable]

    # Define the geographical boundaries for Madagascar
    lat_bounds = [bbox[1], bbox[3]]  # from south to north
    lon_bounds = [bbox[0], bbox[2]]  # from west to east

    # Check for NaN values in latitude and longitude coordinates
    lat_nan = dataset['lat'].isnull().any()
    lon_nan = dataset['lon'].isnull().any()

    # Handle NaN values if they exist
    if lat_nan:
        dataset = dataset.dropna(dim='lat', how='all')
    if lon_nan:
        dataset = dataset.dropna(dim='lon', how='all')

    # Ensure no NaN values in the data itself
    dataset = dataset.fillna(np.nan)  # or use another appropriate method like interpolation

    # Ensure the lat/lon bounds are within the data's range
    lat_min, lat_max = dataset['lat'].min().item(), dataset['lat'].max().item()
    lon_min, lon_max = dataset['lon'].min().item(), dataset['lon'].max().item()

    if lat_bounds[0] < lat_min or lat_bounds[1] > lat_max or lon_bounds[0] < lon_min or lon_bounds[1] > lon_max:
        raise ValueError("The specified latitude/longitude bounds are outside the range of the dataset.")

    # Subset the data using where and dropna
    dataset = dataset.where(
        (dataset['lat'] >= lat_bounds[0]) & (dataset['lat'] <= lat_bounds[1]) &
        (dataset['lon'] >= lon_bounds[0]) & (dataset['lon'] <= lon_bounds[1]),
        drop=True
    )

    # return xr.Dataset(data)
    return dataset
    
def get_spei_significance_dataset(variable='SPEI1', year=2020):
    data_root_folder='/data1/drought_dataset/spei/'
    quality_paths = []
    for month in range(1, 13):
        month_str = f'{month:02d}'
        quality_paths.append(f'{data_root_folder}{variable.lower()}/parameter/{variable}_significance_global_era5_moda_{year}{month_str}_ref1991to2020.nc')
    return xr.open_mfdataset(quality_paths, concat_dim="time", combine='nested', parallel=False)

def get_spi_significance_dataset(variable='SPI1', year=2020):
    data_root_folder='/data1/drought_dataset/spi/'
    quality_paths = []
    for month in range(1, 13):
        month_str = f'{month:02d}'
        quality_paths.append(f'{data_root_folder}{variable.lower()}/parameter/{variable}_significance_global_era5_moda_{year}{month_str}_ref1991to2020.nc')
    return xr.open_mfdataset(quality_paths, concat_dim="time", combine='nested', parallel=False)

Load dataset#

# Load dataset
spei_data = get_spei_dataset(acc_period=12, years=list(range(1940, 2025)))
spei48_region = mask_invalid_values(spei_data, variable='SPEI12')

Filter dataset for specific bounding box#

# Get a subset of the dataset for a bbox
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
world = world.to_crs(epsg=4326)
# country_list = world['name'].unique().tolist()
# country_list.sort()
# country_shape = world[world['name'] == 'Kenya']
country_shape = world[world['name'] == 'S. Sudan']
spei_data = spei_data.rio.write_crs("EPSG:4326", inplace=True)

spei_data_country = spei48_region.rio.clip(country_shape.geometry, world.crs, drop=True)
/tmp/ipykernel_2051206/1990148899.py:2: FutureWarning: The geopandas.dataset module is deprecated and will be removed in GeoPandas 1.0. You can get the original 'naturalearth_lowres' data from https://www.naturalearthdata.com/downloads/110m-cultural-vectors/.
  world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
spei = spei_data_country['SPEI12']
spei
<xarray.DataArray 'SPEI12' (time: 1009, lat: 34, lon: 45)> Size: 12MB
dask.array<getitem, shape=(1009, 34, 45), dtype=float64, chunksize=(1, 34, 45), chunktype=numpy.ndarray>
Coordinates:
  * time         (time) datetime64[ns] 8kB 1940-01-01T06:00:00 ... 2024-01-01...
  * lon          (lon) float64 360B 24.25 24.5 24.75 25.0 ... 34.75 35.0 35.25
  * lat          (lat) float64 272B 12.0 11.75 11.5 11.25 ... 4.5 4.25 4.0 3.75
    spatial_ref  int64 8B 0
Attributes:
    long_name:  Standardized Drought Index (SPEI12)
    units:      -
spei.hvplot(
    clim=(-8,8),
    groupby="time",
    widget_type="scrubber", 
    widget_location="bottom", 
    projection=ccrs.PlateCarree(), 
    coastline='10m',
    cmap='BrBG',
    features=['borders']
)

Analyse each month and find if there was a drought while at the same time classify the conditions for the whole region. E.g. there was a severe drought in a time point if for least a minimum number of grid points SPEI < -1.5#

np.count_nonzero(~np.isnan(spei[49]).values)
# spei[49]
818

Setup drought severity classification function and classes#

import xarray as xr
import numpy as np

def classify_drought_severity(spei, classes, conditions, threshold=50):
    """
    Classifies drought severity based on SPEI values and counts grid points in each class.

    Parameters:
    - spei: An xarray DataArray containing SPEI values (dimensions: time, lat, lon).
    - classes: A list of class names (e.g., ['Extreme Drought', 'Severe Drought', ...]).
    - conditions: A list of conditions corresponding to each class.
    - threshold: Minimum number of grid points required to classify a time step into a specific class.

    Returns:
    - result_df: A pandas DataFrame with counts of grid points in each class for each time step,
      including a 'Final Classification' column.
    """

    # Count the number of grid points in each condition for each time step
    counts = [condition.sum(dim=['lat', 'lon']) for condition in conditions]

    # Combine counts along a new dimension called 'class'
    counts_concat = xr.concat(counts, dim=pd.Index(classes, name="class"))

    # Convert to DataFrame
    counts_df = counts_concat.to_dataframe(name='count').reset_index()

    # Pivot the DataFrame to have classes as columns
    result_df = counts_df.pivot(index='time', columns='class', values='count').fillna(0)

    # Determine the final classification for each time step based on the threshold
    def classify_row(row):
        for class_name in classes:
            if row[class_name] >= threshold:
                return class_name
        return 'No Data'  # If no class meets the threshold

    result_df['Final Classification'] = result_df.apply(classify_row, axis=1)

    return result_df

# Example usage
# Load the dataset (assuming it's already in xarray format)
# ds = xr.open_dataset('your_dataset.nc')  # Uncomment if loading from file
# spei = ds['SPEI']  # Replace 'SPEI' with your actual variable name

# Define the conditions and corresponding classes
conditions = [
    spei < -2, # 'Extremely dry'
    (spei >= -2) & (spei < -1.5), # 'Severely dry'
    (spei >= -1.5) & (spei < -1), # 'Moderately dry'
    (spei >= -1) & (spei < 0), # 'Mildly dry'
    (spei >= 0) & (spei <= 1), # 'Mildly wet'
    (spei >= 1) & (spei <= 1.5), # 'Moderately wet'
    (spei >= 1.5) & (spei <= 2), # 'Severely wet'
    spei > 2 # 'Extremely wet'
]
classes = ['Extremely Dry', 
           'Severely Dry', 
           'Moderately Dry', 
           'Mildly Dry', 
           'Mildly Wet', 
           'Moderately Wet', 
           'Severely Wet', 
           'Extremely Wet']

Classify months in spei#

# Get the result DataFrame
result_df = classify_drought_severity(spei, classes, conditions)
result_df = result_df.reset_index()
# Output the result
result_df
class time Extremely Dry Extremely Wet Mildly Dry Mildly Wet Moderately Dry Moderately Wet Severely Dry Severely Wet Final Classification
0 1940-01-01 06:00:00 0 0 0 0 0 0 0 0 No Data
1 1940-02-01 06:00:00 0 0 0 0 0 0 0 0 No Data
2 1940-03-01 06:00:00 0 0 0 0 0 0 0 0 No Data
3 1940-04-01 06:00:00 0 0 0 0 0 0 0 0 No Data
4 1940-05-01 06:00:00 0 0 0 0 0 0 0 0 No Data
... ... ... ... ... ... ... ... ... ... ...
1004 2023-09-01 06:00:00 756 0 5 0 13 0 44 0 Extremely Dry
1005 2023-10-01 06:00:00 723 0 14 0 27 0 54 0 Extremely Dry
1006 2023-11-01 06:00:00 644 0 33 2 41 0 98 0 Extremely Dry
1007 2023-12-01 06:00:00 647 0 34 1 38 0 98 0 Extremely Dry
1008 2024-01-01 06:00:00 652 0 40 1 39 0 86 0 Extremely Dry

1009 rows × 10 columns

Generate barplot for the dataset to visuallize drought events

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import matplotlib.dates as mdates

# Map classifications to colors
class_color_map = {
    'Extremely Dry': '#291700',  # Red
    'Severely Dry': '#8F5100',   # Orange
    'Moderately Dry': '#FF961F', # Yellow
    'Mildly Dry': '#FFBF69',     # Light Green
    'Mildly Wet': '#A1F7D5',           # Green
    'Moderately Wet': '#43EFAA',   # Dark Green
    'Severely Wet': '#0D965F',
    'Extremely Wet': '#074B30',
    'No Data': '#cccccc'  # Gray
}

# Map the classifications to colors
result_df['Color'] = result_df['Final Classification'].map(class_color_map)

# Create the plot
plt.figure(figsize=(12, 4))  # Adjust the width and height of the plot

# Plot bars
plt.bar(result_df['time'], 1, color=result_df['Color'], width=60, align='edge')  # Adjust width for visibility

# Customize x-axis and y-axis
plt.gca().yaxis.set_visible(False)  # Hide y-axis

# Set x-axis major locator and formatter to show only yearly ticks
plt.gca().xaxis.set_major_locator(mdates.YearLocator())  # Place ticks at yearly intervals
plt.gca().xaxis.set_major_formatter(mdates.DateFormatter('%Y'))  # Format x-axis labels to show only year

# Set x-axis limits
plt.xlim(pd.Timestamp(result_df.time.min()), pd.Timestamp(result_df.time.max()))

# Rotate x-axis labels for better readability
plt.xticks(rotation=90)

# Label the x-axis
plt.xlabel('Time')

# Set the title of the plot
plt.title('Drought Classification Over Time')

# Add legend
handles = [plt.Line2D([0], [0], color=color, lw=4) for color in class_color_map.values()]
labels = list(class_color_map.keys())
plt.legend(handles, labels, title='Drought Classification', bbox_to_anchor=(1.05, 1), loc='upper left')

# Adjust layout for better fit
plt.tight_layout()

# Show the plot
plt.show()
../../_images/d3f862b5cbd84203036d598fe45de20522fb9a1269100126eb0e9b4c776b6aa5.png
import plotly.graph_objs as go
import pandas as pd
import numpy as np


# Map classifications to colors
class_color_map = {
    'Extremely Dry': '#291700',  # Red
    'Severely Dry': '#8F5100',   # Orange
    'Moderately Dry': '#FF961F', # Yellow
    'Mildly Dry': '#FFBF69',     # Light Green
    'Mildly Wet': '#A1F7D5',           # Green
    'Moderately Wet': '#43EFAA',   # Dark Green
    'Severely Wet': '#0D965F',
    'Extremely Wet': '#074B30',
    'No Data': '#cccccc'  # Gray
}

# Map the classifications to colors
result_df['Color'] = result_df['Final Classification'].map(class_color_map)

# Create the plot
fig = go.Figure()

legend_order = [
    'Extremely Dry', 'Severely Dry', 'Moderately Dry', 'Mildly Dry',
    'Mildly Wet', 'Moderately Wet', 'Severely Wet', 'Extremely Wet',
    'No Data'
]

# Add bars
for lbl in legend_order:
    fig.add_trace(go.Bar(
        x=result_df['time'].loc[result_df['Final Classification']==lbl],
        y=[1] * len(result_df),
        name=lbl,
        marker=dict(color=result_df['Color'].loc[result_df['Final Classification']==lbl], line=dict(width=0)),
        width=60 * 24 * 60 * 60 * 1000,  # Width in milliseconds
        orientation='v',  # Vertical bars
        # name='Drought Classification'
    ))

x_min = result_df['time'].min()
x_max = result_df['time'].max()

# Update x-axis and y-axis
fig.update_xaxes(
    title_text='Time (Months)',
    tickformat='%Y',  # Format x-axis labels to show only year
    tickangle=90,  # Rotate x-axis labels
    rangeslider_visible=False,  # Hide the range slider
    type='date',# Ensure x-axis is treated as dates
    range=[x_min, x_max]
)

fig.update_yaxes(
    visible=False  # Hide y-axis
)

# Add legend
fig.update_layout(
    title='Drought Classification Over Time',
    legend_title='Drought Classification',
    legend=dict(
        x=1.05,  # Positioning the legend to the right of the plot
        y=1,
        orientation='v',
        traceorder='normal'  # Ensure legend entries are in the order they appear in the plot
    ),
    margin=dict(l=50, r=200, t=50, b=50),
    paper_bgcolor='white',
    plot_bgcolor='white',
    font=dict(
        color='#2a3f5f',
        family='sans-serif'
    ),
)

# Show the plot
fig.show()

Setup function to detect continuous periods of a condition#

def detect_continuous_periods_with_dates(df, binary_col, date_col, min_sep=1):
    """
    Detects continuous periods of 1s in a binary vector within a DataFrame and returns a new DataFrame
    with the start date, end date, and duration of each period.
    
    Parameters:
    - df: Input DataFrame containing the binary vector and dates.
    - binary_col: Column name for the binary vector (0s and 1s).
    - date_col: Column name for the corresponding dates.
    - min_sep: Minimum number of continuous 0s required to separate periods of 1s.
    
    Returns:
    - periods_df: A DataFrame with 'Start Date', 'End Date', and 'Duration' columns.
    """
    
    # Ensure binary_col is binary (0s and 1s)
    assert df[binary_col].isin([0, 1]).all(), "The binary column must contain only 0s and 1s."
    
    # Detect transitions in the binary column
    transitions = df[binary_col].diff().fillna(0)
    
    # Find where the vector changes from 0 to 1 (start of 1s) and 1 to 0 (end of 1s)
    start_ones = transitions == 1
    end_ones = transitions == -1
    
    # Get the indices of these transitions
    start_indices = start_ones[start_ones].index
    end_indices = end_ones[end_ones].index
    
    # If the series starts with 1s, add a start at the beginning
    if df[binary_col].iloc[0] == 1:
        start_indices = pd.Index([df.index[0]]).append(start_indices)
    
    # If the series ends with 1s, add an end at the end
    if df[binary_col].iloc[-1] == 1:
        end_indices = end_indices.append(pd.Index([df.index[-1]]))
    
    # Ensure indices are aligned
    assert len(start_indices) == len(end_indices), "Mismatched start and end periods."
    
    # Filter out periods that are too close to each other based on min_sep
    valid_periods = []
    last_end = -min_sep - 1  # Initialize last_end to be far enough back
    
    for start, end in zip(start_indices, end_indices):
        if start - last_end >= min_sep:
            valid_periods.append((start, end))
        last_end = end
    
    # Create a new DataFrame for the detected periods
    periods = []
    for start, end in valid_periods:
        start_date = df.loc[start, date_col]
        end_date = df.loc[end, date_col]
        duration = (end_date.year - start_date.year) * 12 + end_date.month - start_date.month + 1  # Duration in months
        periods.append({'Start Date': start_date, 'End Date': end_date, 'Duration': duration})
    
    periods_df = pd.DataFrame(periods)
    return periods_df

Convert the timeline to a binary vector.#

Every dry condition is marked as drought and everything else as no drought. A minimum separation of 2 months with no drought is regarded as no change.

min_sep = 2  # Minimum separation of 2 zeros to consider periods distinct

result_df['class'] = np.where((result_df['Final Classification']=='Extremely Dry')|
                              (result_df['Final Classification']=='Severely Dry')|
                              (result_df['Final Classification']=='Moderately Dry')|
                              (result_df['Final Classification']=='Mildly Dry'), 1, 0)

Find the continuous periods and calculate their duration#

periods_df = detect_continuous_periods_with_dates(result_df, binary_col='class', date_col='time', min_sep=min_sep)
periods_df
Start Date End Date Duration
0 1941-04-01 06:00:00 1944-05-01 06:00:00 38
1 1944-08-01 06:00:00 1946-10-01 06:00:00 27
2 1963-09-01 06:00:00 1963-11-01 06:00:00 3
3 1964-07-01 06:00:00 1964-08-01 06:00:00 2
4 1965-08-01 06:00:00 1967-05-01 06:00:00 22
5 1967-12-01 06:00:00 1968-08-01 06:00:00 9
6 1984-10-01 06:00:00 1985-03-01 06:00:00 6
7 1988-04-01 06:00:00 1988-07-01 06:00:00 4
8 1990-05-01 06:00:00 1991-04-01 06:00:00 12
9 1993-10-01 06:00:00 1994-09-01 06:00:00 12
10 1997-08-01 06:00:00 1998-02-01 06:00:00 7
11 1998-11-01 06:00:00 1999-10-01 06:00:00 12
12 2000-04-01 06:00:00 2001-11-01 06:00:00 20
13 2002-02-01 06:00:00 2002-03-01 06:00:00 2
14 2002-05-01 06:00:00 2024-01-01 06:00:00 261

Plot all the event durations and find the 75% percentile to find drought events with an anomalous duration#

def plot_duration_bar_plot(data, percentile=75):
    percentile_9_duration = np.percentile(data.Duration, 90)
    percentile_1_duration = np.percentile(data.Duration, 10)
    median_duration = data.Duration.median()
    
    
    # Create the plot
    plt.figure(figsize=(10, 6))
    
    # Create bars for each event
    bars = plt.bar(data.index, data['Duration'], color='skyblue', edgecolor='black')
    
    # Add a dashed red line for the average duration
    plt.axhline(y=percentile_9_duration, color='red', linestyle='--', linewidth=2, label=f'{90} percentile of durations: {percentile_9_duration:.2f} months')
    plt.axhline(y=percentile_1_duration, color='green', linestyle='--', linewidth=2, label=f'{10} percentile of durations: {percentile_1_duration:.2f} months')
    plt.axhline(y=median_duration, color='blue', linestyle='--', linewidth=2, label=f'Median duration: {median_duration:.2f} months')
    
    # Labeling the x-axis ticks with the start and end dates
    xticks_labels = [f"{start.strftime('%Y-%m')} - {end.strftime('%Y-%m')}" for start, end in zip(data['Start Date'], data['End Date'])]
    plt.xticks(ticks=np.arange(len(data.index)), labels=xticks_labels)
    
    # Label axes
    plt.xlabel('Events')
    plt.ylabel('Duration (Months)')
    plt.title('Event Durations with Start and End Dates')
    
    # Add legend
    plt.legend()
    
    # Rotate x-axis labels for better readability
    plt.xticks(rotation=45, ha='right')
    
    # Adjust layout for better fit
    plt.tight_layout()
    
    # Show the plot
    plt.show()
plot_duration_bar_plot(periods_df)
../../_images/69f711287711fe93456608eecebf360ff8cb764ec30d30960f0f02598b28ef3b.png
def plot_duration_bar_plot(data, percentile=75):
    percentile_9_duration = np.percentile(data.Duration, 90)
    percentile_1_duration = np.percentile(data.Duration, 10)
    median_duration = data.Duration.median()

    # Generate x-axis labels based on the dates
    x_labels = [f"{start.strftime('%Y-%m')} - {end.strftime('%Y-%m')}" for start, end in zip(data['Start Date'], data['End Date'])]

    # Create a numerical x-axis for the plot
    x_numeric = list(range(len(x_labels)))

    # Create bars for each event
    bar = go.Bar(
        x=x_numeric,
        y=data['Duration'],
        marker=dict(color='skyblue', line=dict(color='black', width=1)),
        name='Period',
        # text=x_labels,
        # textposition='auto'
    )
    
    # Define the x-axis range for the lines
    line_x_values = [x_numeric[0] - 1, x_numeric[-1] + 1]  # Extend beyond the first and last data point
    
    # Create lines for percentiles and median
    percentile_9_line = go.Scatter(
        x=line_x_values,
        y=[percentile_9_duration, percentile_9_duration],
        mode='lines',
        line=dict(color='red', dash='dash'),
        name=f'90th percentile: {percentile_9_duration:.2f} months'
    )
    
    percentile_1_line = go.Scatter(
        x=line_x_values,
        y=[percentile_1_duration, percentile_1_duration],
        mode='lines',
        line=dict(color='green', dash='dash'),
        name=f'10th percentile: {percentile_1_duration:.2f} months'
    )
    
    median_line = go.Scatter(
        x=line_x_values,
        y=[median_duration, median_duration],
        mode='lines',
        line=dict(color='blue', dash='dash'),
        name=f'Median: {median_duration:.2f} months'
    )
    
    # Create the layout
    layout = go.Layout(
        title='Event Durations with Start and End Dates',
        xaxis=dict(
            title='Events',
            tickangle=-45,
            tickmode='array',
            tickvals=x_numeric,
            ticktext=x_labels,
            range=[x_numeric[0] - 1, x_numeric[-1] + 1],  # Extend x-axis range
        ),
        yaxis=dict(title='Duration (Months)'),
        barmode='group',
        legend=dict(x=0.5,xanchor="center", y=-1, orientation='h'),
        margin=dict(l=50, r=50, t=50, b=100),
        paper_bgcolor='white',  # Transparent background for the entire paper
        plot_bgcolor='white',
        font=dict(
            color='#2a3f5f',
            family='sans-serif'
            ),
    )
    
    # Create the figure and add the traces
    fig = go.Figure(data=[bar, percentile_9_line, percentile_1_line, median_line], layout=layout)
    
    # Show the plot
    fig.show()
plot_duration_bar_plot(periods_df)

Calculate area percentage for each class for each month and aggregate for each event#

def calculate_area_percentage(monthly_data, periods):
    columns_to_use = ['Extremely Dry',
                     'Severely Dry', 
                     'Moderately Dry',
                     'Mildly Dry', 
                     'Mildly Wet', 
                     'Moderately Wet',
                     'Severely Wet',
                     'Extremely Wet']
    
    new_columns = ['Extremely Dry %',
                   'Severely Dry %', 
                   'Moderately Dry %',
                   'Mildly Dry %',
                   'Mildly Wet %', 
                   'Moderately Wet %',
                   'Severely Wet %',
                   'Extremely Wet %']
    
    rows = []
    for i, row in periods.iterrows():
        start_date = row['Start Date']
        end_date = row['End Date']
        df = monthly_data.loc[(monthly_data.time >= start_date) & (monthly_data.time <= end_date)]
        total = df[columns_to_use].sum(axis=1)
        # Calculate the percentage for each specified column
        df_percentage = df[columns_to_use].div(total, axis=0) * 100
        cols = {i[0]:i[1] for i in zip(columns_to_use, new_columns)}
        df_percentage.rename(columns=cols,inplace=-True)
        # Add the percentage columns back to the original dataframe, if needed
        df.loc[:, new_columns] = df_percentage
        rows.append(df[new_columns].mean(axis=0))
    new_df = pd.concat(rows, axis=1).T.reset_index(drop=True)
    new_df['Start Date'] = periods['Start Date']
    new_df['End Date'] = periods['End Date']
    return new_df
percentages = calculate_area_percentage(result_df, periods_df)
percentages
class Extremely Dry % Severely Dry % Moderately Dry % Mildly Dry % Mildly Wet % Moderately Wet % Severely Wet % Extremely Wet % Start Date End Date
0 1.209626 2.843907 6.144640 26.647150 38.138592 7.933342 6.157509 10.925235 1941-04-01 06:00:00 1944-05-01 06:00:00
1 1.616409 2.598931 5.030336 31.694286 48.691479 5.637055 2.354433 2.377071 1944-08-01 06:00:00 1946-10-01 06:00:00
2 0.040750 0.000000 0.407498 7.538712 29.951100 17.807661 15.729421 28.524857 1963-09-01 06:00:00 1963-11-01 06:00:00
3 3.361858 1.222494 1.589242 5.867971 47.188264 32.273839 7.212714 1.283619 1964-07-01 06:00:00 1964-08-01 06:00:00
4 3.778617 1.828184 3.217382 26.605912 44.654368 12.736164 5.140031 2.039342 1965-08-01 06:00:00 1967-05-01 06:00:00
5 2.811736 1.725075 3.558816 11.735941 24.789459 16.503667 13.922847 24.952459 1967-12-01 06:00:00 1968-08-01 06:00:00
6 0.081500 0.651997 1.976365 9.881826 50.040750 31.744091 4.095355 1.528117 1984-10-01 06:00:00 1985-03-01 06:00:00
7 0.061125 0.000000 0.000000 10.849633 77.567237 8.863081 2.353301 0.305623 1988-04-01 06:00:00 1988-07-01 06:00:00
8 0.061125 0.000000 0.478810 14.343928 73.471883 10.177262 1.375306 0.091687 1990-05-01 06:00:00 1991-04-01 06:00:00
9 0.937245 2.006927 3.382233 17.817848 58.842706 16.554605 0.458435 0.000000 1993-10-01 06:00:00 1994-09-01 06:00:00
10 0.174642 0.157178 1.432064 14.477820 55.483758 27.209221 1.065316 0.000000 1997-08-01 06:00:00 1998-02-01 06:00:00
11 0.081500 0.804808 3.636919 10.951508 31.081907 48.268134 5.175224 0.000000 1998-11-01 06:00:00 1999-10-01 06:00:00
12 0.201711 0.482885 1.167482 17.732274 71.479218 8.655257 0.281174 0.000000 2000-04-01 06:00:00 2001-11-01 06:00:00
13 0.000000 0.000000 0.000000 4.400978 83.007335 12.530562 0.061125 0.000000 2002-02-01 06:00:00 2002-03-01 06:00:00
14 6.823015 12.851643 16.260574 41.335282 20.840476 1.189707 0.494150 0.205154 2002-05-01 06:00:00 2024-01-01 06:00:00
percentages['Dry'] = percentages.loc[:, ['Extremely Dry %', 'Severely Dry %', 'Moderately Dry %', 'Mildly Dry %']].sum(axis=1)
# new_order = ['time',
#              'Extremely Dry',
#              'Extremely Dry %',
#              'Severely Dry',
#              'Severely Dry %',
#              'Moderately Dry',
#              'Moderately Dry %',
#              'Mildly Dry',
#              'Mildly Dry %',
#              'Mildly Wet',
#              'Mildly Wet %',
#              'Moderately Wet',
#              'Moderately Wet %',
#              'Severely Wet',
#              'Severely Wet %',
#              'Extremely Wet',
#              'Extremely Wet %',
#              'Color',
#              'class']
# df1 = df[new_order]
percentages
class Extremely Dry % Severely Dry % Moderately Dry % Mildly Dry % Mildly Wet % Moderately Wet % Severely Wet % Extremely Wet % Start Date End Date Dry
0 1.209626 2.843907 6.144640 26.647150 38.138592 7.933342 6.157509 10.925235 1941-04-01 06:00:00 1944-05-01 06:00:00 36.845322
1 1.616409 2.598931 5.030336 31.694286 48.691479 5.637055 2.354433 2.377071 1944-08-01 06:00:00 1946-10-01 06:00:00 40.939962
2 0.040750 0.000000 0.407498 7.538712 29.951100 17.807661 15.729421 28.524857 1963-09-01 06:00:00 1963-11-01 06:00:00 7.986960
3 3.361858 1.222494 1.589242 5.867971 47.188264 32.273839 7.212714 1.283619 1964-07-01 06:00:00 1964-08-01 06:00:00 12.041565
4 3.778617 1.828184 3.217382 26.605912 44.654368 12.736164 5.140031 2.039342 1965-08-01 06:00:00 1967-05-01 06:00:00 35.430096
5 2.811736 1.725075 3.558816 11.735941 24.789459 16.503667 13.922847 24.952459 1967-12-01 06:00:00 1968-08-01 06:00:00 19.831568
6 0.081500 0.651997 1.976365 9.881826 50.040750 31.744091 4.095355 1.528117 1984-10-01 06:00:00 1985-03-01 06:00:00 12.591687
7 0.061125 0.000000 0.000000 10.849633 77.567237 8.863081 2.353301 0.305623 1988-04-01 06:00:00 1988-07-01 06:00:00 10.910758
8 0.061125 0.000000 0.478810 14.343928 73.471883 10.177262 1.375306 0.091687 1990-05-01 06:00:00 1991-04-01 06:00:00 14.883863
9 0.937245 2.006927 3.382233 17.817848 58.842706 16.554605 0.458435 0.000000 1993-10-01 06:00:00 1994-09-01 06:00:00 24.144254
10 0.174642 0.157178 1.432064 14.477820 55.483758 27.209221 1.065316 0.000000 1997-08-01 06:00:00 1998-02-01 06:00:00 16.241705
11 0.081500 0.804808 3.636919 10.951508 31.081907 48.268134 5.175224 0.000000 1998-11-01 06:00:00 1999-10-01 06:00:00 15.474735
12 0.201711 0.482885 1.167482 17.732274 71.479218 8.655257 0.281174 0.000000 2000-04-01 06:00:00 2001-11-01 06:00:00 19.584352
13 0.000000 0.000000 0.000000 4.400978 83.007335 12.530562 0.061125 0.000000 2002-02-01 06:00:00 2002-03-01 06:00:00 4.400978
14 6.823015 12.851643 16.260574 41.335282 20.840476 1.189707 0.494150 0.205154 2002-05-01 06:00:00 2024-01-01 06:00:00 77.270513
def plot_area_bar_plot(data, columns_to_sum=['Moderately Dry %',
                                             'Mildly Dry %',
                                             'Mildly Wet %',
                                             'Moderately Wet %',
                                             'Severely Wet %',
                                             'Extremely Wet %']):
    columns = [i for i in data.columns if '%' in i and i not in columns_to_sum]

    fig = go.Figure()
    x_axis_labels =  [f"{start.strftime('%Y-%m')} - {end.strftime('%Y-%m')}" for start, end in zip(data['Start Date'], data['End Date'])]

    # Adding bars for each category
    if columns_to_sum:
        fig.add_trace(go.Bar(
            x=x_axis_labels,
            y=data[columns_to_sum].sum(axis=1),
            name='Normal',
            marker=dict(color=class_color_map['Severely Wet'])
        ))
    for category in columns[::-1]:
        fig.add_trace(go.Bar(
            x=x_axis_labels,
            y=data[category],
            name=category[:-2],
            marker=dict(color=class_color_map[category[:-2]])
        ))
    
    # Updating the layout for stacked bar
    fig.update_layout(
        barmode='stack',  # This ensures the bars are stacked
        title='Area of each type of drought',
        xaxis=dict(title='Events',
                   tickangle=-45,  # Rotate the x-axis labels by -45 degrees
                   tickmode='array',
                   tickvals=x_axis_labels,
                   ticktext=x_axis_labels,),
        yaxis=dict(title='Percentage'),
        legend=dict(orientation='v',x=1, y=0.5),
        margin=dict(l=50, r=50, t=50, b=100),
        paper_bgcolor='white',  # Transparent background for the entire paper
        plot_bgcolor='white',
        font=dict(
            color='#2a3f5f',
            family='sans-serif'
            ),
    )
    
    # Show the plot
    fig.show()
plot_area_bar_plot(percentages, columns_to_sum=[])
plot_area_bar_plot(percentages)